import os        
import json
import gradio as gr
import functools
from copy import copy
from typing import List, Optional, Union, Callable, Dict, Tuple, Literal
from dataclasses import dataclass
import numpy as np

from lib_controlnet.utils import svg_preprocess, read_image, judge_image_type
from lib_controlnet import (
    global_state,
    external_code,
)
from lib_controlnet.external_code import ControlNetUnit
from lib_controlnet.logging import logger
from lib_controlnet.controlnet_ui.openpose_editor import OpenposeEditor
from lib_controlnet.controlnet_ui.preset import ControlNetPresetUI
from lib_controlnet.controlnet_ui.tool_button import ToolButton
from lib_controlnet.controlnet_ui.photopea import Photopea
from lib_controlnet.controlnet_ui.multi_inputs_gallery import MultiInputsGallery
from lib_controlnet.enums import InputMode, HiResFixOption
from modules import shared, script_callbacks
from modules.ui_components import FormRow
from modules_forge.forge_util import HWC3
from lib_controlnet.enums import (
    InputMode,
    HiResFixOption,
    PuLIDMode,
    ControlNetUnionControlType,
)


@dataclass
class A1111Context:
    """Contains all components from A1111.

    Each field starts as ``None`` and is filled by ``set_component`` when the
    A1111 WebUI creates a component whose ``elem_id`` is listed in
    ``_ID_MAPPING``.
    """

    # Maps an A1111 component `elem_id` to the field of this class that holds
    # it. Deliberately unannotated so `dataclass` does not turn it into an
    # instance field: it must stay out of `vars(self)`, which
    # `ui_initialized` and `set_component` iterate over.
    _ID_MAPPING = {
        "img2img_batch_input_dir": "img2img_batch_input_dir",
        "img2img_batch_output_dir": "img2img_batch_output_dir",
        "txt2img_generate": "txt2img_submit_button",
        "img2img_generate": "img2img_submit_button",
        "txt2img_width": "txt2img_w_slider",
        "txt2img_height": "txt2img_h_slider",
        "img2img_width": "img2img_w_slider",
        "img2img_height": "img2img_h_slider",
        "img2img_img2img_tab": "img2img_img2img_tab",
        "img2img_img2img_sketch_tab": "img2img_img2img_sketch_tab",
        "img2img_batch_tab": "img2img_batch_tab",
        "img2img_inpaint_tab": "img2img_inpaint_tab",
        "img2img_inpaint_sketch_tab": "img2img_inpaint_sketch_tab",
        "img2img_inpaint_upload_tab": "img2img_inpaint_upload_tab",
        "img2img_inpaint_full_res": "img2img_inpaint_area",
        "txt2img_hr-checkbox": "txt2img_enable_hr",
    }

    # Fields allowed to stay `None` when deciding whether the UI is ready.
    # Optional components are only available after A1111 v1.7.0.
    # (Unannotated for the same reason as `_ID_MAPPING`.)
    _OPTIONAL_COMPONENTS = frozenset(
        {
            "img2img_img2img_tab",
            "img2img_img2img_sketch_tab",
            "img2img_batch_tab",
            "img2img_inpaint_tab",
            "img2img_inpaint_sketch_tab",
            "img2img_inpaint_upload_tab",
        }
    )

    img2img_batch_input_dir: Optional[gr.components.IOComponent] = None
    img2img_batch_output_dir: Optional[gr.components.IOComponent] = None
    txt2img_submit_button: Optional[gr.components.IOComponent] = None
    img2img_submit_button: Optional[gr.components.IOComponent] = None

    # Slider controls from A1111 WebUI.
    txt2img_w_slider: Optional[gr.components.IOComponent] = None
    txt2img_h_slider: Optional[gr.components.IOComponent] = None
    img2img_w_slider: Optional[gr.components.IOComponent] = None
    img2img_h_slider: Optional[gr.components.IOComponent] = None

    img2img_img2img_tab: Optional[gr.components.IOComponent] = None
    img2img_img2img_sketch_tab: Optional[gr.components.IOComponent] = None
    img2img_batch_tab: Optional[gr.components.IOComponent] = None
    img2img_inpaint_tab: Optional[gr.components.IOComponent] = None
    img2img_inpaint_sketch_tab: Optional[gr.components.IOComponent] = None
    img2img_inpaint_upload_tab: Optional[gr.components.IOComponent] = None

    img2img_inpaint_area: Optional[gr.components.IOComponent] = None
    txt2img_enable_hr: Optional[gr.components.IOComponent] = None

    @property
    def img2img_inpaint_tabs(self) -> Tuple[gr.components.IOComponent, ...]:
        """The three img2img inpaint tabs (any of which may still be None)."""
        return (
            self.img2img_inpaint_tab,
            self.img2img_inpaint_sketch_tab,
            self.img2img_inpaint_upload_tab,
        )

    @property
    def img2img_non_inpaint_tabs(self) -> Tuple[gr.components.IOComponent, ...]:
        """The img2img, sketch and batch tabs (any of which may still be None)."""
        return (
            self.img2img_img2img_tab,
            self.img2img_img2img_sketch_tab,
            self.img2img_batch_tab,
        )

    @property
    def ui_initialized(self) -> bool:
        """True once every non-optional component has been recorded."""
        # Use an explicit `is not None` check, matching the "already set"
        # test in `set_component`, instead of relying on component truthiness.
        return all(
            component is not None
            for name, component in vars(self).items()
            if name not in self._OPTIONAL_COMPONENTS
        )

    def set_component(self, component: gr.components.IOComponent):
        """Record `component` if its `elem_id` is one this context tracks.

        Components without a tracked `elem_id` are ignored. Do not set a
        component if it has already been set; the first one seen wins.
        https://github.com/Mikubill/sd-webui-controlnet/issues/2587
        """
        elem_id = getattr(component, "elem_id", None)
        field_name = self._ID_MAPPING.get(elem_id)
        if field_name is None or getattr(self, field_name) is not None:
            return

        setattr(self, field_name, component)
        logger.debug(f"Setting {elem_id}.")
        logger.debug(
            f"A1111 initialized {sum(c is not None for c in vars(self).values())}/{len(vars(self))}."
        )


class ControlNetUiGroup(object):
    # Unicode glyphs used as the labels of this group's small tool buttons.
    refresh_symbol = "\U0001f504"  # 🔄
    switch_values_symbol = "\U000021C5"  # ⇅
    camera_symbol = "\U0001F4F7"  # 📷
    reverse_symbol = "\U000021C4"  # ⇄
    tossup_symbol = "\u2934"  # ⤴
    trigger_symbol = "\U0001F4A5"  # 💥
    open_symbol = "\U0001F4DD"  # 📝

    # Hover text for the tool buttons, keyed by the glyph shown on each button.
    # Keys reuse the constants above so a glyph and its tooltip cannot drift.
    tooltips = {
        refresh_symbol: "Refresh",
        tossup_symbol: "Send dimensions to stable diffusion",
        trigger_symbol: "Run preprocessor",
        open_symbol: "Open new canvas",
        camera_symbol: "Enable webcam",
        reverse_symbol: "Mirror webcam",
    }

    # Single batch-input textbox created once, when the class body runs.
    global_batch_input_dir = gr.Textbox(
        label="Controlnet input directory",
        placeholder="Leave empty to use input directory",
        **shared.hide_dirs,
        elem_id="controlnet_batch_input_dir",
    )
    # A1111 components, registered through `A1111Context.set_component`.
    a1111_context = A1111Context()
    # All ControlNetUiGroup instances created.
    all_ui_groups: List["ControlNetUiGroup"] = []

    @property
    def width_slider(self):
        """The A1111 width slider of the tab (img2img or txt2img) this unit is in."""
        ctx = ControlNetUiGroup.a1111_context
        return ctx.img2img_w_slider if self.is_img2img else ctx.txt2img_w_slider

    @property
    def height_slider(self):
        """The A1111 height slider of the tab (img2img or txt2img) this unit is in."""
        ctx = ControlNetUiGroup.a1111_context
        return ctx.img2img_h_slider if self.is_img2img else ctx.txt2img_h_slider

    def __init__(
        self,
        is_img2img: bool,
        default_unit: external_code.ControlNetUnit,
        photopea: Optional[Photopea] = None,
    ):
        """Create a ControlNet unit UI group; components are built by `render`.

        Args:
            is_img2img: Whether this group belongs to the img2img tab.
            default_unit: Unit whose values seed the rendered components.
            photopea: Optional Photopea integration used by `render`.
        """
        # Whether callbacks have been registered.
        self.callbacks_registered: bool = False
        # Whether the render method on this object has been called.
        self.ui_initialized: bool = False

        self.is_img2img = is_img2img
        self.default_unit = default_unit
        self.photopea = photopea
        self.webcam_enabled = False
        self.webcam_mirrored = False

        # Note: All gradio elements declared in `render` will be defined as member variable.
        # Update counter to trigger a force update of ControlNetUnit.
        # dummy_gradio_update_trigger is useful when a field with no event subscriber available changes.
        # e.g. gr.Gallery, gr.State, etc. After an update to gr.State / gr.Gallery, please increment
        # this counter to trigger a sync update of ControlNetUnit.
        self.dummy_gradio_update_trigger = None
        self.enabled = None
        self.upload_tab = None
        self.image = None
        self.generated_image_group = None
        self.generated_image = None
        self.mask_image_group = None
        self.mask_image = None
        self.batch_tab = None
        self.batch_image_dir = None
        self.batch_mask_dir = None
        self.batch_upload_tab = None
        self.batch_input_gallery = None
        self.batch_mask_gallery = None
        self.multi_inputs_upload_tab = None
        self.multi_inputs_gallery = None
        # `render` never assigns this name (it uses `multi_inputs_gallery`);
        # kept in case other code still reads it.
        self.multi_inputs_input_gallery = None
        self.create_canvas = None
        self.canvas_width = None
        self.canvas_height = None
        self.canvas_create_button = None
        self.canvas_cancel_button = None
        self.open_new_canvas_button = None
        self.webcam_enable = None
        self.webcam_mirror = None
        self.send_dimen_button = None
        self.pixel_perfect = None
        self.preprocessor_preview = None
        self.mask_upload = None
        self.union_control_type = None
        self.type_filter = None
        self.module = None
        self.trigger_preprocessor = None
        self.model = None
        self.refresh_models = None
        self.weight = None
        self.guidance_start = None
        self.guidance_end = None
        self.advanced = None
        self.processor_res = None
        self.threshold_a = None
        self.threshold_b = None
        self.control_mode = None
        self.resize_mode = None
        self.use_preview_as_input = None
        self.openpose_editor = None
        self.preset_panel = None
        self.upload_independent_img_in_img2img = None
        self.image_upload_panel = None
        self.save_detected_map = None
        self.input_mode = gr.State(InputMode.SIMPLE)
        self.hr_option = None
        self.ipa_block_weight = None
        self.ipa_block_weight_selector = None
        self.ipa_block_weight_save_button = None
        self.batch_image_dir_state = None
        self.output_dir_state = None
        self.advanced_weighting = gr.State(None)

        # Internal states for UI state pasting.
        self.prevent_next_n_module_update = 0
        self.prevent_next_n_slider_value_update = 0

        ControlNetUiGroup.all_ui_groups.append(self)

    def render(self, tabname: str, elem_id_tabname: str) -> gr.State:
        """Build the gradio layout of a single ControlNet unit.

        Every gradio element created here is stored on `self`, and the core
        event handlers are then registered via `register_core_callbacks`.

        Args:
            tabname: Tab name; combined with `elem_id_tabname` to build the
                `elem_id` of each element created here.
            elem_id_tabname: Prefix for the `elem_id` of each element created here.

        Returns:
            A `gr.State` holding the `ControlNetUnit` assembled from the
            current UI values. It is rebuilt whenever one of the unit's input
            components fires an event, and when this tab's generate button is
            clicked.
        """
        self.dummy_gradio_update_trigger = gr.Number(value=0, visible=False)
        self.openpose_editor = OpenposeEditor()

        with gr.Group(visible=not self.is_img2img) as self.image_upload_panel:
            self.save_detected_map = gr.Checkbox(value=True, visible=False)
            with gr.Tabs():
                with gr.Tab(label="Single Image") as self.upload_tab:
                    with gr.Row(elem_classes=["cnet-image-row"], equal_height=True):
                        with gr.Group(elem_classes=["cnet-input-image-group"]):
                            self.image = gr.Image(
                                source="upload",
                                brush_radius=20,
                                mirror_webcam=False,
                                type="numpy",
                                tool="sketch",
                                elem_id=f"{elem_id_tabname}_{tabname}_input_image",
                                elem_classes=["cnet-image"],
                                brush_color=shared.opts.img2img_inpaint_mask_brush_color
                                if hasattr(
                                    shared.opts, "img2img_inpaint_mask_brush_color"
                                )
                                else None,
                            )
                            self.image.preprocess = functools.partial(
                                svg_preprocess, preprocess=self.image.preprocess
                            )
                            self.openpose_editor.render_upload()

                        with gr.Group(
                            visible=False, elem_classes=["cnet-generated-image-group"]
                        ) as self.generated_image_group:
                            self.generated_image = gr.Image(
                                value=None,
                                label="Preprocessor Preview",
                                elem_id=f"{elem_id_tabname}_{tabname}_generated_image",
                                elem_classes=["cnet-image"],
                                interactive=True,
                                height=242,
                            )  # Gradio's magic number. Only 242 works.

                            with gr.Group(
                                elem_classes=["cnet-generated-image-control-group"]
                            ):
                                if self.photopea:
                                    self.photopea.render_child_trigger()
                                self.openpose_editor.render_edit()
                                preview_check_elem_id = f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_preview_checkbox"
                                # The "Close" link un-ticks the "Allow Preview"
                                # checkbox created below under this elem_id.
                                preview_close_button_js = f"document.querySelector('#{preview_check_elem_id} input[type=\\'checkbox\\']').click();"
                                gr.HTML(
                                    value=f"""<a title="Close Preview" onclick="{preview_close_button_js}">Close</a>""",
                                    visible=True,
                                    elem_classes=["cnet-close-preview"],
                                )

                        with gr.Group(
                            visible=False, elem_classes=["cnet-mask-image-group"]
                        ) as self.mask_image_group:
                            self.mask_image = gr.Image(
                                value=None,
                                label="Mask",
                                elem_id=f"{elem_id_tabname}_{tabname}_mask_image",
                                elem_classes=["cnet-mask-image"],
                                interactive=True,
                                brush_radius=20,
                                type="numpy",
                                tool="sketch",
                                brush_color=shared.opts.img2img_inpaint_mask_brush_color
                                if hasattr(
                                    shared.opts, "img2img_inpaint_mask_brush_color"
                                )
                                else None,
                            )

                with gr.Tab(label="Batch Folder") as self.batch_tab:
                    with gr.Row():
                        self.batch_image_dir = gr.Textbox(
                            label="Input Directory",
                            placeholder="Input directory path to the control images.",
                            elem_id=f"{elem_id_tabname}_{tabname}_batch_image_dir",
                        )
                        self.batch_mask_dir = gr.Textbox(
                            label="Mask Directory",
                            placeholder="Mask directory path to the control images.",
                            elem_id=f"{elem_id_tabname}_{tabname}_batch_mask_dir",
                            visible=False,
                        )

                with gr.Tab(label="Batch Upload") as self.batch_upload_tab:
                    with gr.Row():
                        self.batch_input_gallery = MultiInputsGallery()
                        self.batch_mask_gallery = MultiInputsGallery(
                            visible=False,
                            elem_classes=["cnet-mask-gallery-group"]
                        )

                with gr.Tab(label="Multi-Inputs") as self.multi_inputs_upload_tab:
                    with gr.Row():
                        self.multi_inputs_gallery = MultiInputsGallery()

            if self.photopea:
                self.photopea.attach_photopea_output(self.generated_image)

            with gr.Accordion(
                label="Open New Canvas", visible=False
            ) as self.create_canvas:
                self.canvas_width = gr.Slider(
                    label="New Canvas Width",
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_width",
                )
                self.canvas_height = gr.Slider(
                    label="New Canvas Height",
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_height",
                )
                with gr.Row():
                    self.canvas_create_button = gr.Button(
                        value="Create New Canvas",
                        elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_create_button",
                    )
                    self.canvas_cancel_button = gr.Button(
                        value="Cancel",
                        elem_id=f"{elem_id_tabname}_{tabname}_controlnet_canvas_cancel_button",
                    )

            with gr.Row(elem_classes="controlnet_image_controls"):
                gr.HTML(
                    value="<p>Set the preprocessor to [invert] If your image has white background and black lines.</p>",
                    elem_classes="controlnet_invert_warning",
                )
                self.open_new_canvas_button = ToolButton(
                    value=ControlNetUiGroup.open_symbol,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_open_new_canvas_button",
                    tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.open_symbol],
                )
                self.webcam_enable = ToolButton(
                    value=ControlNetUiGroup.camera_symbol,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_enable",
                    tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.camera_symbol],
                )
                self.webcam_mirror = ToolButton(
                    value=ControlNetUiGroup.reverse_symbol,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_webcam_mirror",
                    tooltip=ControlNetUiGroup.tooltips[
                        ControlNetUiGroup.reverse_symbol
                    ],
                )
                self.send_dimen_button = ToolButton(
                    value=ControlNetUiGroup.tossup_symbol,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_send_dimen_button",
                    tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.tossup_symbol],
                )

        with FormRow(elem_classes=["controlnet_main_options"]):
            self.enabled = gr.Checkbox(
                label="Enable",
                value=self.default_unit.enabled,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_enable_checkbox",
                elem_classes=["cnet-unit-enabled"],
            )
            self.pixel_perfect = gr.Checkbox(
                label="Pixel Perfect",
                value=self.default_unit.pixel_perfect,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_pixel_perfect_checkbox",
            )
            self.preprocessor_preview = gr.Checkbox(
                label="Allow Preview",
                value=False,
                elem_classes=["cnet-allow-preview"],
                elem_id=preview_check_elem_id,
                visible=not self.is_img2img,
            )
            self.mask_upload = gr.Checkbox(
                label="Use Mask",
                value=False,
                elem_classes=["cnet-mask-upload"],
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_mask_upload_checkbox",
                visible=not self.is_img2img,
            )
            self.use_preview_as_input = gr.Checkbox(
                label="Preview as Input",
                value=False,
                elem_classes=["cnet-preview-as-input"],
                visible=False,
            )

        with gr.Row(elem_classes="controlnet_img2img_options"):
            if self.is_img2img:
                self.upload_independent_img_in_img2img = gr.Checkbox(
                    label="Upload independent control image",
                    value=False,
                    elem_id=f"{elem_id_tabname}_{tabname}_controlnet_same_img2img_checkbox",
                    elem_classes=["cnet-unit-same_img2img"],
                )
            else:
                self.upload_independent_img_in_img2img = None

        with gr.Row():
            self.union_control_type = gr.Textbox(
                label="Union Control Type",
                value=ControlNetUnionControlType.UNKNOWN.value,
                visible=False,
            )

        with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
            self.type_filter = gr.Radio(
                global_state.get_all_preprocessor_tags(),
                label="Control Type",
                value="All",
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_type_filter_radio",
                elem_classes="controlnet_control_type_filter_group",
            )

        with gr.Row(elem_classes=["controlnet_preprocessor_model", "controlnet_row"]):
            self.module = gr.Dropdown(
                global_state.get_all_preprocessor_names(),
                label="Preprocessor",
                value=self.default_unit.module,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_dropdown",
            )
            self.trigger_preprocessor = ToolButton(
                value=ControlNetUiGroup.trigger_symbol,
                visible=not self.is_img2img,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_trigger_preprocessor",
                elem_classes=["cnet-run-preprocessor"],
                tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.trigger_symbol],
            )
            self.model = gr.Dropdown(
                global_state.get_all_controlnet_names(),
                label="Model",
                value=self.default_unit.model,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_model_dropdown",
            )
            self.refresh_models = ToolButton(
                value=ControlNetUiGroup.refresh_symbol,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_refresh_models",
                tooltip=ControlNetUiGroup.tooltips[ControlNetUiGroup.refresh_symbol],
            )

        with gr.Row(elem_classes=["controlnet_weight_steps", "controlnet_row"]):
            self.weight = gr.Slider(
                label="Control Weight",
                value=self.default_unit.weight,
                minimum=0.0,
                maximum=2.0,
                step=0.05,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_weight_slider",
                elem_classes="controlnet_control_weight_slider",
            )
            self.guidance_start = gr.Slider(
                label="Starting Control Step",
                value=self.default_unit.guidance_start,
                minimum=0.0,
                maximum=1.0,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_start_control_step_slider",
                elem_classes="controlnet_start_control_step_slider",
            )
            self.guidance_end = gr.Slider(
                label="Ending Control Step",
                value=self.default_unit.guidance_end,
                minimum=0.0,
                maximum=1.0,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_ending_control_step_slider",
                elem_classes="controlnet_ending_control_step_slider",
            )

        # advanced options
        with gr.Column(visible=False) as self.advanced:
            self.processor_res = gr.Slider(
                label="Preprocessor resolution",
                value=self.default_unit.processor_res,
                minimum=64,
                maximum=2048,
                visible=False,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_preprocessor_resolution_slider",
            )
            self.threshold_a = gr.Slider(
                label="Threshold A",
                value=self.default_unit.threshold_a,
                minimum=64,
                maximum=1024,
                visible=False,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_A_slider",
            )
            self.threshold_b = gr.Slider(
                label="Threshold B",
                value=self.default_unit.threshold_b,
                minimum=64,
                maximum=1024,
                visible=False,
                interactive=True,
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_threshold_B_slider",
            )

        with gr.Row(elem_classes=["controlnet_extra_control", "controlnet_row"]):
            self.control_mode = gr.Radio(
                choices=[e.value for e in external_code.ControlMode],
                value=self.default_unit.control_mode.value,
                label="Control Mode",
                elem_id=f"{elem_id_tabname}_{tabname}_controlnet_control_mode_radio",
                elem_classes="controlnet_control_mode_radio",
            )
            self.ipa_block_weight_selector = gr.Dropdown(
                choices=list(external_code.ipa_block_weight_presets.keys()),
                label="[SDXL] IP-A Block Weights",
                value=list(external_code.ipa_block_weight_presets.keys())[0],
                elem_id=f"{elem_id_tabname}_{tabname}_ipa_block_weight",
                allow_custom_value=False,
                visible=False,
            )
            self.ipa_block_weight = gr.Textbox(
                label="Weights",
                visible=False,
                value="",
                placeholder="Preset or custom 11XL weights, e.g.: 0,0,0,0, 0.5, 1,1,1,1,1,1",
            )
            self.ipa_block_weight_save_button = ToolButton(
                value="\U0001f4be",
                elem_classes=["cnet-ipa-preset-save"],
                tooltip="Save Custom Preset",
                visible=False,
            )

        # File holding the user's saved custom IP-Adapter block weights
        # (path is relative to the process working directory).
        IPA_CW_PATH = os.path.join("tmp", "ipa_custom_block_weight.txt")

        def toggle_ipa_controlls(choice):
            """Show the IP-A block-weight controls only for the "IP-Adapter" type."""
            visible = choice == "IP-Adapter"
            # One fresh update dict per output component (selector textbox,
            # preset dropdown, save button).
            return (
                gr.update(visible=visible),
                gr.update(visible=visible),
                gr.update(visible=visible),
            )

        def handle_dropdown_selection(alias):
            """Return the weight string for the selected block-weight preset.

            Aliases without "Custom" are looked up in the built-in presets
            (empty string if unknown). A "Custom" alias returns the first line
            of the saved custom preset file, or "" if none has been saved.
            An empty/None selection yields "" instead of raising TypeError.
            """
            if not alias:
                return ""
            if "Custom" not in alias:
                return external_code.ipa_block_weight_presets.get(alias, "")
            if not os.path.exists(IPA_CW_PATH):
                return ""
            with open(IPA_CW_PATH, "r", encoding="utf-8") as file:
                return file.readline().strip()

        def fn_save_ipa_custom(value):
            """Save `value` as the custom preset and select the last preset entry.

            NOTE(review): the last preset key is assumed to be the "Custom"
            entry — confirm against `external_code.ipa_block_weight_presets`.
            """
            # "tmp" is not guaranteed to exist; create it rather than failing
            # with FileNotFoundError on the first save.
            os.makedirs(os.path.dirname(IPA_CW_PATH), exist_ok=True)
            with open(IPA_CW_PATH, "w", encoding="utf-8") as file:
                file.write(value)
            return gr.Dropdown.update(
                value=list(external_code.ipa_block_weight_presets.keys())[-1]
            )

        self.type_filter.change(
            toggle_ipa_controlls,
            inputs=self.type_filter,
            outputs=[
                self.ipa_block_weight,
                self.ipa_block_weight_selector,
                self.ipa_block_weight_save_button,
            ],
        )
        self.ipa_block_weight_selector.change(
            handle_dropdown_selection,
            inputs=self.ipa_block_weight_selector,
            outputs=self.ipa_block_weight,
        )
        self.ipa_block_weight_save_button.click(
            fn=fn_save_ipa_custom,
            inputs=self.ipa_block_weight,
            outputs=self.ipa_block_weight_selector,
            show_progress=False,
        )

        self.resize_mode = gr.Radio(
            choices=[e.value for e in external_code.ResizeMode],
            value=self.default_unit.resize_mode.value,
            label="Resize Mode",
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_resize_mode_radio",
            elem_classes="controlnet_resize_mode_radio",
            visible=not self.is_img2img,
        )

        self.hr_option = gr.Radio(
            choices=[e.value for e in HiResFixOption],
            value=self.default_unit.hr_option.value,
            label="Hires-Fix Option",
            elem_id=f"{elem_id_tabname}_{tabname}_controlnet_hr_option_radio",
            elem_classes="controlnet_hr_option_radio",
            visible=False,
        )

        # self.loopback = gr.Checkbox(
        #     label="[Batch Loopback] Automatically send generated images to this ControlNet unit in batch generation",
        #     value=self.default_unit.loopback,
        #     elem_id=f"{elem_id_tabname}_{tabname}_controlnet_automatically_send_generated_images_checkbox",
        #     elem_classes="controlnet_loopback_checkbox",
        #     visible=False,
        # )

        self.preset_panel = ControlNetPresetUI(
            id_prefix=f"{elem_id_tabname}_{tabname}_"
        )

        self.batch_image_dir_state = gr.State("")
        self.output_dir_state = gr.State("")
        # Inputs used to build a ControlNetUnit. They are paired positionally
        # with ControlNetUnit's field names in `create_unit`, so this order
        # must follow the field declaration order.
        unit_args = (
            self.input_mode,
            self.use_preview_as_input,
            self.batch_image_dir,
            self.batch_mask_dir,
            self.batch_input_gallery.input_gallery,
            self.batch_mask_gallery.input_gallery,
            self.multi_inputs_gallery.input_gallery,
            self.generated_image,
            self.mask_image,
            self.hr_option,
            self.enabled,
            self.module,
            self.model,
            self.weight,
            self.image,
            self.resize_mode,
            self.processor_res,
            self.threshold_a,
            self.threshold_b,
            self.guidance_start,
            self.guidance_end,
            self.pixel_perfect,
            self.control_mode,
            self.advanced_weighting,
            self.ipa_block_weight,
        )

        unit = gr.State(self.default_unit)
        # Field names of ControlNetUnit, in attribute order. Computed once here
        # instead of constructing a throwaway ControlNetUnit on every UI event.
        unit_field_names = tuple(vars(ControlNetUnit()).keys())

        def create_unit(*args):
            """Build a ControlNetUnit from the current values of `unit_args`."""
            return ControlNetUnit.from_dict(dict(zip(unit_field_names, args)))

        # Rebuild the unit state on the most specific event each component
        # offers (edit > click > slider release > change), plus `clear`.
        for comp in unit_args + (self.dummy_gradio_update_trigger,):
            event_subscribers = []
            if hasattr(comp, "edit"):
                event_subscribers.append(comp.edit)
            elif hasattr(comp, "click"):
                event_subscribers.append(comp.click)
            elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
                event_subscribers.append(comp.release)
            elif hasattr(comp, "change"):
                event_subscribers.append(comp.change)

            if hasattr(comp, "clear"):
                event_subscribers.append(comp.clear)

            for event_subscriber in event_subscribers:
                event_subscriber(
                    fn=create_unit, inputs=list(unit_args), outputs=unit
                )

        # Also refresh the unit when generation starts, unqueued.
        (
            ControlNetUiGroup.a1111_context.img2img_submit_button
            if self.is_img2img
            else ControlNetUiGroup.a1111_context.txt2img_submit_button
        ).click(
            fn=create_unit,
            inputs=list(unit_args),
            outputs=unit,
            queue=False,
        )
        self.register_core_callbacks()
        self.ui_initialized = True
        return unit

    def register_send_dimensions(self):
        """Wire the send-dimensions button.

        A click copies the input image's width and height, each snapped to a
        multiple of 8, into this tab's generation width/height sliders. With
        no image, both sliders are left unchanged.
        """

        def send_dimensions(image):
            def snap_to_multiple_of_8(value):
                # Remainders 0..4 round down, 5..7 round up.
                remainder = value % 8
                if remainder > 4:
                    return round(value + (8 - remainder))
                return round(value - remainder)

            if not image:
                return gr.Slider.update(), gr.Slider.update()

            pixels = np.asarray(image.get("image"))
            new_width = snap_to_multiple_of_8(pixels.shape[1])
            new_height = snap_to_multiple_of_8(pixels.shape[0])
            return new_width, new_height

        self.send_dimen_button.click(
            fn=send_dimensions,
            inputs=[self.image],
            outputs=[self.width_slider, self.height_slider],
            show_progress=False,
        )

    def register_webcam_toggle(self):
        """Wire the webcam button to switch the image input's source.

        Each click flips ``self.webcam_enabled``, clears the current image and
        sets the source to ``"webcam"`` or ``"upload"`` accordingly.
        """

        def webcam_toggle():
            self.webcam_enabled = not self.webcam_enabled
            source = "webcam" if self.webcam_enabled else "upload"
            # Raw update dict understood by gradio.
            return dict(value=None, source=source, __type__="update")

        self.webcam_enable.click(
            webcam_toggle,
            inputs=None,
            outputs=self.image,
            show_progress=False,
        )

    def register_webcam_mirror_toggle(self):
        """Wire the mirror button to flip webcam mirroring on the image input.

        Each click inverts ``self.webcam_mirrored`` and pushes the new value to
        the image component's ``mirror_webcam`` property.
        """

        def webcam_mirror_toggle():
            mirrored = not self.webcam_mirrored
            self.webcam_mirrored = mirrored
            return dict(mirror_webcam=mirrored, __type__="update")

        self.webcam_mirror.click(
            webcam_mirror_toggle,
            inputs=None,
            outputs=self.image,
            show_progress=False,
        )

    def register_refresh_all_models(self):
        """Wire the refresh button to rescan ControlNet model files and
        repopulate the model dropdown's choices."""

        def refresh_all_models():
            global_state.update_controlnet_filenames()
            model_names = global_state.get_all_controlnet_names()
            return gr.Dropdown.update(choices=model_names)

        self.refresh_models.click(
            refresh_all_models,
            outputs=[self.model],
            show_progress=False,
        )

    def register_build_sliders(self):
        """Register callbacks that adapt the preprocessor controls.

        * When the preprocessor (``self.module``) or the pixel-perfect toggle
          changes, rebuild the resolution / threshold sliders from the
          preprocessor's slider specs and toggle visibility of the advanced,
          model, refresh and control-mode widgets.
        * When the control type filter changes, narrow the preprocessor and
          model dropdown choices and, unless suppressed through
          ``self.prevent_next_n_module_update``, preselect a default in each.
        """

        def build_sliders(module: str, pp: bool):
            logger.debug(
                f"Prevent update slider value: {self.prevent_next_n_slider_value_update}"
            )
            logger.debug(f"Build slider for module: {module} - {pp}")

            preprocessor = global_state.get_preprocessor(module)

            # Copy so hiding the slider below does not mutate the preprocessor's spec.
            slider_resolution_kwargs = preprocessor.slider_resolution.gradio_update_kwargs.copy()

            if pp:
                # Pixel perfect computes the resolution itself, so hide the slider.
                slider_resolution_kwargs['visible'] = False

            # Order must match `outputs` below.
            return [
                gr.update(**slider_resolution_kwargs),  # processor_res
                gr.update(**preprocessor.slider_1.gradio_update_kwargs),  # threshold_a
                gr.update(**preprocessor.slider_2.gradio_update_kwargs),  # threshold_b
                gr.update(visible=True),  # advanced
                gr.update(visible=not preprocessor.do_not_need_model),  # model
                gr.update(visible=not preprocessor.do_not_need_model),  # refresh_models
                gr.update(visible=preprocessor.show_control_mode),  # control_mode
            ]

        inputs = [
            self.module,
            self.pixel_perfect,
        ]
        outputs = [
            self.processor_res,
            self.threshold_a,
            self.threshold_b,
            self.advanced,
            self.model,
            self.refresh_models,
            self.control_mode,
        ]
        self.module.change(
            build_sliders, inputs=inputs, outputs=outputs, show_progress=False
        )
        self.pixel_perfect.change(
            build_sliders, inputs=inputs, outputs=outputs, show_progress=False
        )

        def pick_default(names: List[str], control_type: str) -> Optional[str]:
            """Choose the entry to preselect in a filtered dropdown.

            For a specific control type the second entry is preferred when the
            list has more than one (the first entry is presumably a generic
            placeholder such as "None" -- TODO confirm). For 'All' the first
            entry is used. An empty list yields None instead of IndexError.
            """
            if not names:
                return None
            if control_type != 'All' and len(names) > 1:
                return names[1]
            return names[0]

        def filter_selected(k: str):
            logger.debug(f"Prevent update {self.prevent_next_n_module_update}")
            logger.debug(f"Switch to control type {k}")

            filtered_preprocessor_list = global_state.get_filtered_preprocessor_names(k)
            filtered_controlnet_names = global_state.get_filtered_controlnet_names(k)

            if self.prevent_next_n_module_update > 0:
                # Consume one suppression: refresh choices, keep current values.
                self.prevent_next_n_module_update -= 1
                return [
                    gr.Dropdown.update(choices=filtered_preprocessor_list),
                    gr.Dropdown.update(choices=filtered_controlnet_names),
                ]

            return [
                gr.Dropdown.update(
                    value=pick_default(filtered_preprocessor_list, k),
                    choices=filtered_preprocessor_list,
                ),
                gr.Dropdown.update(
                    value=pick_default(filtered_controlnet_names, k),
                    choices=filtered_controlnet_names,
                ),
            ]

        self.type_filter.change(
            fn=filter_selected,
            inputs=[self.type_filter],
            outputs=[self.module, self.model],
            show_progress=False,
        )

    def register_union_control_type(self):
        """Mirror the control type filter into the union control type component.

        The selected filter string is parsed with
        ``ControlNetUnionControlType.from_str`` and its ``value`` is written to
        ``self.union_control_type``.
        """

        def on_type_filter_change(k: str):
            union_type = ControlNetUnionControlType.from_str(k)
            logger.debug(f"Switch to union control type {union_type}")
            return gr.update(value=union_type.value)

        self.type_filter.change(
            fn=on_type_filter_change,
            inputs=[self.type_filter],
            outputs=[self.union_control_type],
            show_progress=False,
        )

    def register_run_annotator(self):
        """Wire the "run preprocessor" button to preview the selected preprocessor.

        The callback runs the preprocessor on the current input image/mask and
        shows the result in ``self.generated_image``. It also checks the preview
        checkbox and forwards any openpose JSON to the openpose editor.
        """

        class JsonAcceptor:
            """Stores the last JSON pose handed over by a preprocessor."""

            def __init__(self) -> None:
                self.value = ""

            def accept(self, json_dict: dict) -> None:
                self.value = json.dumps(json_dict)

        def run_annotator(image, module, pres, pthr_a, pthr_b, t2i_w, t2i_h, pp, rm):
            if image is None:
                # Nothing to process: empty preview, reset the pose editor.
                return (
                    gr.update(value=None, visible=True),
                    gr.update(),
                    *self.openpose_editor.update(""),
                )

            img = HWC3(image["image"])
            mask = HWC3(image["mask"])
            # A mask without any pixel above 5 is treated as absent.
            if not (mask > 5).any():
                mask = None

            preprocessor = global_state.get_preprocessor(module)

            if pp:
                pres = external_code.pixel_perfect_resolution(
                    img,
                    target_H=t2i_h,
                    target_W=t2i_w,
                    resize_mode=external_code.resize_mode_from_value(rm),
                )

            json_acceptor = JsonAcceptor()

            logger.info(f"Preview Resolution = {pres}")

            # Only openpose modules produce JSON, so only they get a callback.
            # A fresh acceptor per call would otherwise defeat the preprocessor
            # cache for every other module.
            # TODO: Maybe we should let `preprocessor` return a Dict to alleviate
            # this issue? This requires changing all callsites though.
            pose_callback = json_acceptor.accept if "openpose" in module else None
            result = preprocessor(
                input_image=img,
                resolution=pres,
                slider_1=pthr_a,
                slider_2=pthr_b,
                input_mask=mask,
                json_pose_callback=pose_callback,
            )

            # Fall back to the input when the preprocessor did not yield an image.
            if not judge_image_type(result):
                result = img

            result = external_code.visualize_inpaint_mask(result)
            return (
                gr.update(value=result, visible=True, interactive=False),  # generated_image
                gr.update(value=True),  # preprocessor_preview
                *self.openpose_editor.update(json_acceptor.value),  # openpose editor
            )

        annotator_inputs = [
            self.image,
            self.module,
            self.processor_res,
            self.threshold_a,
            self.threshold_b,
            self.width_slider,
            self.height_slider,
            self.pixel_perfect,
            self.resize_mode,
        ]
        annotator_outputs = [
            self.generated_image,
            self.preprocessor_preview,
            *self.openpose_editor.outputs(),
        ]
        self.trigger_preprocessor.click(
            fn=run_annotator,
            inputs=annotator_inputs,
            outputs=annotator_outputs,
        )

    def register_shift_preview(self):
        """Show or hide the preview widgets when the preview checkbox toggles."""

        def shift_preview(is_on):
            # A bare `gr.update()` leaves the component as it is.
            if is_on:
                image_update = gr.update()
                link_update = gr.update()
                modal_update = gr.update()
            else:
                image_update = gr.update(value=None)
                link_update = gr.update(value=None)
                modal_update = gr.update(visible=False)
            return (
                image_update,  # generated_image
                gr.update(visible=is_on),  # generated_image_group
                gr.update(visible=False),  # use_preview_as_input (now managed automatically)
                link_update,  # download_pose_link
                modal_update,  # modal edit button
            )

        preview_outputs = [
            self.generated_image,
            self.generated_image_group,
            self.use_preview_as_input,
            self.openpose_editor.download_link,
            self.openpose_editor.modal,
        ]
        self.preprocessor_preview.change(
            fn=shift_preview,
            inputs=[self.preprocessor_preview],
            outputs=preview_outputs,
            show_progress=False,
        )

    def register_create_canvas(self):
        """Wire the "new canvas" accordion.

        The open and cancel buttons show and hide the accordion. The create
        button fills ``self.image`` with a black uint8 RGB canvas of the chosen
        height/width and hides the accordion again.
        """

        def accordion_visibility(visible):
            return lambda: gr.Accordion.update(visible=visible)

        for button, visible in (
            (self.open_new_canvas_button, True),
            (self.canvas_cancel_button, False),
        ):
            button.click(
                accordion_visibility(visible),
                inputs=None,
                outputs=self.create_canvas,
                show_progress=False,
            )

        def fn_canvas(h, w):
            blank = np.zeros(shape=(h, w, 3), dtype=np.uint8)
            return blank, gr.Accordion.update(visible=False)

        self.canvas_create_button.click(
            fn=fn_canvas,
            inputs=[self.canvas_height, self.canvas_width],
            outputs=[self.image, self.create_canvas],
            show_progress=False,
        )

    def register_img2img_same_input(self):
        """React to the img2img "upload independent image" checkbox.

        Any change clears the unit's image and batch dir, unchecks the
        preprocessor preview, and shows or hides the preview checkbox, upload
        panel, preprocessor trigger and resize mode to match the checkbox.
        """

        def fn_same_checked(checked):
            cleared_inputs = [gr.update(value=None), gr.update(value=None)]
            preview = [gr.update(value=False, visible=checked)]
            panels = [gr.update(visible=checked)] * 3
            return cleared_inputs + preview + panels

        checkbox = self.upload_independent_img_in_img2img
        checkbox.change(
            fn_same_checked,
            inputs=checkbox,
            outputs=[
                self.image,
                self.batch_image_dir,
                self.preprocessor_preview,
                self.image_upload_panel,
                self.trigger_preprocessor,
                self.resize_mode,
            ],
            show_progress=False,
        )

    def register_shift_crop_input_image(self):
        """Intentionally a no-op: nothing is registered for cropping.

        Still called for img2img units from ``register_callbacks``.
        """
        return None

    def register_shift_hr_options(self):
        """Show the unit's hires-fix option only while the A1111 txt2img
        "enable hr" checkbox is checked."""
        enable_hr = ControlNetUiGroup.a1111_context.txt2img_enable_hr

        def follow_enable_hr(checked):
            return gr.update(visible=checked)

        enable_hr.change(
            fn=follow_enable_hr,
            inputs=[enable_hr],
            outputs=[self.hr_option],
            show_progress=False,
        )

    def register_shift_upload_mask(self):
        """Control visibility of the mask upload inputs.

        Checking "upload mask" shows the single, batch-dir and gallery mask
        inputs and seeds the single mask with a blank canvas sized from the
        unit's height/width sliders. Unchecking hides them and clears their
        values. In img2img, turning off the independent input also unchecks
        and hides the "upload mask" checkbox.
        """

        def on_checkbox_click(checked: bool, canvas_height: int, canvas_width: int):
            if checked:
                # Blank canvas matching the generation target size.
                blank = np.zeros(shape=(canvas_height, canvas_width, 3), dtype=np.uint8)
                return (
                    gr.update(visible=True),  # single mask group
                    gr.update(value=blank),  # single mask image
                    gr.update(visible=True),  # batch mask dir
                    gr.update(visible=True),  # multi mask gallery group
                    gr.update(),  # multi mask gallery (unchanged)
                )
            # Unchecked: hide everything and drop any provided masks.
            return (
                gr.update(visible=False),
                gr.update(value=None),
                gr.update(value=None, visible=False),
                gr.update(visible=False),
                gr.update(value=None)
            )

        self.mask_upload.change(
            fn=on_checkbox_click,
            inputs=[self.mask_upload, self.height_slider, self.width_slider],
            outputs=[
                self.mask_image_group,
                self.mask_image,
                self.batch_mask_dir,
                self.batch_mask_gallery.group,
                self.batch_mask_gallery.input_gallery,
            ],
            show_progress=False,
        )

        independent_input = self.upload_independent_img_in_img2img
        if independent_input is None:
            return

        def sync_with_independent_input(checked):
            if checked:
                return gr.update(visible=True)
            # Uncheck `upload_mask` when not using independent input.
            return gr.update(visible=False, value=False)

        independent_input.change(
            fn=sync_with_independent_input,
            inputs=[independent_input],
            outputs=[self.mask_upload],
            show_progress=False,
        )

    def register_sync_batch_dir(self):
        """Keep the unit's batch directory states in sync on blur.

        The input state takes the first non-empty value among the unit's own
        batch dir, the global batch input dir and the A1111 img2img batch input
        dir. The output state mirrors the A1111 img2img batch output dir.
        """

        def determine_batch_dir(batch_dir, fallback_dir, fallback_fallback_dir):
            # First truthy candidate wins; otherwise the last is returned as-is.
            return batch_dir or fallback_dir or fallback_fallback_dir

        a1111 = ControlNetUiGroup.a1111_context
        batch_dirs = [
            self.batch_image_dir,
            ControlNetUiGroup.global_batch_input_dir,
            a1111.img2img_batch_input_dir,
        ]
        for component in batch_dirs:
            blur = getattr(component, "blur", None)
            if blur is not None:
                blur(
                    fn=determine_batch_dir,
                    inputs=batch_dirs,
                    outputs=[self.batch_image_dir_state],
                    queue=False,
                )

        output_dir = a1111.img2img_batch_output_dir
        output_dir.blur(
            fn=lambda a: a,
            inputs=[output_dir],
            outputs=[self.output_dir_state],
            queue=False,
        )

    def register_clear_preview(self):
        """Cancel "use preview as input" whenever an input that affects the
        preprocessor result changes, and clear the generated preview image."""

        def clear_preview(x):
            if x:
                logger.info("Preview as input is cancelled.")
            return gr.update(value=False), gr.update(value=None)

        def subscribers_of(comp):
            # First available of edit / click / release (sliders only) / change,
            # plus `clear` when the component has it.
            if hasattr(comp, "edit"):
                found = [comp.edit]
            elif hasattr(comp, "click"):
                found = [comp.click]
            elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
                found = [comp.release]
            elif hasattr(comp, "change"):
                found = [comp.change]
            else:
                found = []
            if hasattr(comp, "clear"):
                found.append(comp.clear)
            return found

        watched_components = (
            self.pixel_perfect,
            self.module,
            self.image,
            self.processor_res,
            self.threshold_a,
            self.threshold_b,
            self.upload_independent_img_in_img2img,
        )
        for comp in watched_components:
            for subscribe in subscribers_of(comp):
                subscribe(
                    fn=clear_preview,
                    inputs=self.use_preview_as_input,
                    outputs=[self.use_preview_as_input, self.generated_image],
                    show_progress=False
                )

    def register_multi_images_upload(self):
        """Register callbacks on merge tab multiple images upload.

        Every gallery bumps ``self.dummy_gradio_update_trigger`` by one on
        change. That trigger is one of the components whose events rebuild the
        unit state.
        """
        trigger = self.dummy_gradio_update_trigger
        # One shared spec dict is handed to every gallery.
        trigger_dict = {
            "fn": lambda n: gr.update(value=n + 1),
            "inputs": [trigger],
            "outputs": [trigger],
        }
        for gallery in (
            self.batch_input_gallery,
            self.batch_mask_gallery,
            self.multi_inputs_gallery,
        ):
            gallery.register_callbacks(change_trigger=trigger_dict)

    def register_core_callbacks(self):
        """Register callbacks that involve only gradio components owned by
        this UI group (no A1111 components)."""
        for register in (
            self.register_webcam_toggle,
            self.register_webcam_mirror_toggle,
            self.register_refresh_all_models,
            self.register_build_sliders,
            self.register_shift_preview,
            self.register_create_canvas,
            self.register_clear_preview,
            self.register_multi_images_upload,
        ):
            register()

        self.openpose_editor.register_callbacks(
            self.generated_image,
            self.use_preview_as_input,
            self.model,
        )

        assert self.type_filter is not None
        # Components matching the unit's infotext fields, in field order.
        infotext_components = [
            getattr(self, field)
            for field in external_code.ControlNetUnit.infotext_fields()
        ]
        self.preset_panel.register_callbacks(
            self,
            self.type_filter,
            *infotext_components,
        )

        if self.is_img2img:
            self.register_img2img_same_input()

    def register_callbacks(self):
        """Register callbacks that involve A1111 context gradio components.

        Only the first call does anything; later calls return immediately.
        """
        # The flag is set before registering to prevent infinite recursion.
        if self.callbacks_registered:
            return
        self.callbacks_registered = True

        for register in (
            self.register_send_dimensions,
            self.register_run_annotator,
            self.register_sync_batch_dir,
            self.register_shift_upload_mask,
        ):
            register()

        # Mode-specific registration.
        mode_specific = (
            self.register_shift_crop_input_image
            if self.is_img2img
            else self.register_shift_hr_options
        )
        mode_specific()

    @staticmethod
    def register_input_mode_sync(ui_groups: List["ControlNetUiGroup"]):
        """
        - ui_group.input_mode should be updated when user switch tabs.
        - Loopback checkbox should only be visible if at least one ControlNet unit
        is set to batch mode.

        Argument:
            ui_groups: All ControlNetUiGroup instances defined in current Script context.

        Returns:
            None
        """
        if not ui_groups:
            return

        for ui_group in ui_groups:
            batch_fn = lambda: InputMode.BATCH
            simple_fn = lambda: InputMode.SIMPLE
            merge_fn = lambda: InputMode.MERGE
            for input_tab, fn in (
                (ui_group.upload_tab, simple_fn),
                (ui_group.batch_tab, batch_fn),
                (ui_group.batch_upload_tab, batch_fn),
                (ui_group.multi_inputs_upload_tab, merge_fn),
            ):
                # Sync input_mode.
                input_tab.select(
                    fn=fn,
                    inputs=[],
                    outputs=[ui_group.input_mode],
                    show_progress=False,
                )

    @staticmethod
    def reset():
        """Forget all registered A1111 components and UI groups."""
        group_cls = ControlNetUiGroup
        group_cls.a1111_context = A1111Context()
        group_cls.all_ui_groups = list()

    @staticmethod
    def try_register_all_callbacks():
        """Register every unit's callbacks once the whole UI has been built.

        Does nothing unless the A1111 context is initialized, the number of
        UI groups equals twice the configured unit count (txt2img + img2img),
        and every group is initialized and has not registered callbacks yet.
        """
        expected_group_count = shared.opts.data.get("control_net_unit_count", 3) * 2  # txt2img + img2img.
        groups = ControlNetUiGroup.all_ui_groups
        ready = (
            # All A1111 components ControlNet units care about are all registered.
            ControlNetUiGroup.a1111_context.ui_initialized
            and len(groups) == expected_group_count
            and all(g.ui_initialized and not g.callbacks_registered for g in groups)
        )
        if not ready:
            return

        for group in groups:
            group.register_callbacks()

        img2img_groups = [g for g in groups if g.is_img2img]
        ControlNetUiGroup.register_input_mode_sync(img2img_groups)
        txt2img_groups = [g for g in groups if not g.is_img2img]
        ControlNetUiGroup.register_input_mode_sync(txt2img_groups)
        logger.info("ControlNet UI callback registered.")

    @staticmethod
    def on_after_component(component, **_kwargs):
        """Handle a freshly created A1111 component.

        When the img2img batch inpaint mask dir is seen, the shared global
        batch input dir is rendered right after it. Every other component is
        recorded in the A1111 context, followed by an attempt to register all
        callbacks.
        """
        is_mask_dir = getattr(component, "elem_id", None) == "img2img_batch_inpaint_mask_dir"
        if not is_mask_dir:
            ControlNetUiGroup.a1111_context.set_component(component)
            ControlNetUiGroup.try_register_all_callbacks()
            return

        ControlNetUiGroup.global_batch_input_dir.render()
